import math
from math import sqrt
import argparse
from pathlib import Path
import time

# torch

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

# vision imports

from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# dalle classes and utils

from dalle_pytorch import distributed_utils
# from dalle_pytorch.dalle_pytorch_pn2 import DiscreteVAE
from dalle_pytorch.dalle_pytorch_ori import DiscreteVAE
# from dalle_pytorch.dalle_pytorch import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_newest import DiscreteVAE
# from dalle_pytorch.dalle_pytorch_ae import DiscreteVAE

# data loading, distributed training and logging imports

from IPython import embed
import glob
from pytorch3d.io import load_ply
from torch.utils.data import Dataset
import os

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import comm as comm
import h5py
import numpy as np
from tensorboardX import SummaryWriter
import time
from datetime import timedelta

def normalize_points_torch(points, eps=1e-12):
    """Center each point cloud on its centroid and scale it into the unit sphere.

    Args:
        points (torch.Tensor): (batch_size, num_points, 3)
        eps (float): lower bound applied to the per-cloud scaling radius, so
            that a degenerate cloud (all points identical, radius 0) yields
            zeros instead of NaNs from a 0/0 division.

    Returns:
        torch.Tensor: normalized points, same shape as the input

    """
    assert points.dim() == 3 and points.size(2) == 3
    centroid = points.mean(dim=1, keepdim=True)
    points = points - centroid
    # Largest distance from the centroid within each cloud: (batch_size, 1, 1).
    norm, _ = points.norm(dim=2, keepdim=True).max(dim=1, keepdim=True)
    new_points = points / norm.clamp(min=eps)
    return new_points

def setup_ddp(gpu, args):
    """Join the NCCL process group as rank ``gpu`` and bind this process to that GPU.

    Args:
        gpu: index of this process; used both as the distributed rank and as
            the CUDA device id.
        args: namespace providing ``world_size``.

    The rendezvous uses ``env://``, so the master address/port must already be
    present in the environment when this is called.
    """
    process_group_options = dict(
        backend='nccl',  # alternative backend: 'gloo'
        init_method='env://',
        world_size=args.world_size,
        rank=gpu,
    )
    dist.init_process_group(**process_group_options)

    # Same seed in every process.
    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

# constants
def train(rank, args):
    """Train a DiscreteVAE on point-cloud data, optionally with one DDP process per GPU.

    Args:
        rank: index of this process; handed to ``setup_ddp`` (and later used as
            the DDP device id / sampler rank) when ``args.gpus > 1``.
        args: parsed command-line namespace built in the ``__main__`` block.
    """
    # Multi-GPU runs must join the process group before anything touches CUDA.
    if args.gpus > 1:
        setup_ddp(rank, args)

    # Data location and input size.
    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    # Optimisation schedule.
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    LR_DECAY_RATE = args.lr_decay_rate  # not referenced by the active code below

    # Model hyper-parameters forwarded to DiscreteVAE.
    NUM_TOKENS = args.num_tokens
    NUM_LAYERS = args.num_layers
    NUM_RESNET_BLOCKS = args.num_resnet_blocks
    SMOOTH_L1_LOSS = args.smooth_l1_loss
    EMB_DIM = args.emb_dim
    HIDDEN_DIM = args.hidden_dim
    KL_LOSS_WEIGHT = args.kl_loss_weight

    # Temperature settings; only STARTING_TEMP is used by the active code
    # (the annealing step further down is commented out).
    STARTING_TEMP = args.starting_temp
    TEMP_MIN = args.temp_min
    ANNEAL_RATE = args.anneal_rate

    # Only referenced by the commented-out image logging code.
    NUM_IMAGES_SAVE = args.num_images_save


    # data
    class PC_Dataset(Dataset):
        """Point clouds read from the ``*.ply`` files of one dataset folder.

        Each item is a pair: the normalized points (optionally rescaled by a
        random factor drawn from [0.9, 1.05)) and the second value returned by
        ``load_ply`` for that file.
        """

        def __init__(self, path):
            self.data_dir = path
            ply_pattern = os.path.join('/home/tiangel/datasets', self.data_dir, '*.ply')
            self.data_list = glob.glob(ply_pattern)
            self.len = len(self.data_list)
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __len__(self):
            return self.len

        def __getitem__(self, index):
            loaded = load_ply(self.data_list[index])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            batched = loaded[0].unsqueeze(0)
            points = normalize_points_torch(batched).squeeze()
            if self.do_aug:
                factor = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= factor
            return (points, loaded[1])
    class PC_Dataset_h5(Dataset):
        """Point clouds stored in the ``data`` array of a single HDF5 file.

        The whole array is copied into memory at construction time. Each item
        is one normalized cloud, optionally rescaled by a random factor drawn
        from [0.9, 1.05).
        """

        def __init__(self, path):
            # np.array() copies the dataset into memory, so the file can be
            # closed immediately instead of leaking an open handle.
            with h5py.File(os.path.join('/home/tiangel/datasets/', path), 'r') as f:
                self.data = np.array(f['data'])
            self.len = self.data.shape[0]
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            cloud = torch.Tensor(self.data[index])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    class PC_two_Dataset(Dataset):
        """Concatenation of a ``*.ply`` folder and the shape-program HDF5 clouds.

        Indices ``[0, len1)`` address the ply files; the remaining indices
        address rows of the HDF5 ``data`` array. Each item is one normalized
        cloud, optionally rescaled by a random factor drawn from [0.9, 1.05).
        """

        def __init__(self, path):
            self.data_dir = path
            # np.array() copies the dataset into memory, so the file can be
            # closed immediately instead of leaking an open handle.
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 'r') as f:
                self.data2 = np.array(f['data'])
            self.data_list = glob.glob(os.path.join('/home/tiangel/datasets', self.data_dir, '*.ply'))
            self.len1 = len(self.data_list)
            self.len = self.len1 + self.data2.shape[0]
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            if index < self.len1:
                # load_ply returns a tuple; its first element holds the points.
                cloud = load_ply(self.data_list[index])[0]
            else:
                cloud = torch.Tensor(self.data2[index - self.len1])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len

    class PC_Program_Dataset(Dataset):
        """Concatenation of the two shape-program HDF5 point-cloud files.

        Indices ``[0, len1)`` address ``train_shapes_pc.h5``; the remaining
        indices address ``train_shapes_pc_table_40k.h5``. Each item is one
        normalized cloud, optionally rescaled by a random factor drawn from
        [0.9, 1.05).
        """

        def __init__(self, path):
            self.data_dir = path
            # np.array() copies each dataset into memory, so the files can be
            # closed immediately instead of leaking open handles.
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 'r') as f:
                self.data1 = np.array(f['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 'r') as f2:
                self.data2 = np.array(f2['data'])
            self.len1 = self.data1.shape[0]
            self.len = self.len1 + self.data2.shape[0]
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            if index < self.len1:
                cloud = torch.Tensor(self.data1[index])
            else:
                cloud = torch.Tensor(self.data2[index - self.len1])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    class PC_four_Dataset(Dataset):
        """Concatenation of two ``*.ply`` folders and two shape-program HDF5 files.

        Global index layout, in order: ply files under ``path`` (``len1``),
        ``train_shapes_pc.h5`` (``len2``), ``train_shapes_pc_table_40k.h5``
        (``len3``), ply files under ``abo_plys`` (``len4``). Each item is one
        normalized cloud, optionally rescaled by a random factor drawn from
        [0.9, 1.05).
        """

        def __init__(self, path):
            self.data_dir = path
            self.data_list = glob.glob(os.path.join('/home/tiangel/datasets', self.data_dir, '*.ply'))
            self.data_list2 = glob.glob(os.path.join('/home/tiangel/datasets', 'abo_plys', '*.ply'))
            # np.array() copies each dataset into memory, so the files can be
            # closed immediately instead of leaking open handles.
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 'r') as f:
                self.data2 = np.array(f['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 'r') as f2:
                self.data3 = np.array(f2['data'])

            self.len1 = len(self.data_list)
            self.len2 = self.data2.shape[0]
            self.len3 = self.data3.shape[0]
            self.len4 = len(self.data_list2)
            self.len = self.len1 + self.len2 + self.len3 + self.len4
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            end1 = self.len1
            end2 = end1 + self.len2
            end3 = end2 + self.len3
            # Each branch is only reached when the previous bounds failed, so
            # the lower-bound checks are implicit.
            if index < end1:
                # load_ply returns a tuple; its first element holds the points.
                cloud = load_ply(self.data_list[index])[0]
            elif index < end2:
                cloud = torch.Tensor(self.data2[index - end1])
            elif index < end3:
                cloud = torch.Tensor(self.data3[index - end2])
            else:
                cloud = load_ply(self.data_list2[index - end3])[0]
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    class PC_four_h5_Dataset(Dataset):
        """Concatenation of four HDF5 point-cloud files held fully in memory.

        Global index layout, in order: ``shapenet_plys.h5`` (``len1``),
        ``abo_plys.h5`` (``len2``), ``train_shapes_pc_chair_40k.h5`` (``len3``),
        ``train_shapes_pc_table_40k.h5`` (``len4``). Each item is one normalized
        cloud, optionally rescaled by a random factor drawn from [0.9, 1.05).
        The ``path`` argument is stored but not used to locate any file.
        """

        def __init__(self, path):
            self.data_dir = path
            # np.array() copies each dataset into memory, so the files can be
            # closed immediately instead of leaking open handles.
            with h5py.File('/home/tiangel/datasets/shapenet_plys.h5', 'r') as f0:
                self.data0 = np.array(f0['data'])
            with h5py.File('/home/tiangel/datasets/abo_plys.h5', 'r') as f1:
                self.data1 = np.array(f1['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_chair_40k.h5', 'r') as f2:
                self.data2 = np.array(f2['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 'r') as f3:
                self.data3 = np.array(f3['data'])

            self.len1 = self.data0.shape[0]
            self.len2 = self.data1.shape[0]
            self.len3 = self.data2.shape[0]
            self.len4 = self.data3.shape[0]
            self.len = self.len1 + self.len2 + self.len3 + self.len4
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            end1 = self.len1
            end2 = end1 + self.len2
            end3 = end2 + self.len3
            # Each branch is only reached when the previous bounds failed, so
            # the lower-bound checks are implicit.
            if index < end1:
                cloud = torch.Tensor(self.data0[index])
            elif index < end2:
                cloud = torch.Tensor(self.data1[index - end1])
            elif index < end3:
                cloud = torch.Tensor(self.data2[index - end2])
            else:
                cloud = torch.Tensor(self.data3[index - end3])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    class PC_text_Dataset(Dataset):
        """Concatenation of ``shapenet_plys.h5`` and ``abo_plys.h5`` held in memory.

        Indices ``[0, len1)`` address the ShapeNet array; ``[len1, len1 + len2)``
        address the ABO array. Each item is one normalized cloud, optionally
        rescaled by a random factor drawn from [0.9, 1.05). The ``path``
        argument is stored but not used to locate any file.

        Raises:
            IndexError: from ``__getitem__`` when ``index >= len(self)``.
        """

        def __init__(self, path):
            self.data_dir = path
            # np.array() copies each dataset into memory, so the files can be
            # closed immediately instead of leaking open handles.
            with h5py.File('/home/tiangel/datasets/shapenet_plys.h5', 'r') as f0:
                self.data0 = np.array(f0['data'])
            with h5py.File('/home/tiangel/datasets/abo_plys.h5', 'r') as f1:
                self.data1 = np.array(f1['data'])

            self.len1 = self.data0.shape[0]
            self.len2 = self.data1.shape[0]
            self.len = self.len1 + self.len2
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def __getitem__(self, index):
            # Previously an index past the end left `pc` unassigned and failed
            # with UnboundLocalError; report it as a normal sequence error.
            if index >= self.len:
                raise IndexError(f'index {index} is out of range for dataset of size {self.len}')
            if index < self.len1:
                cloud = torch.Tensor(self.data0[index])
            else:
                cloud = torch.Tensor(self.data1[index - self.len1])
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    # ds = PC_Dataset(IMAGE_PATH)
    # ds = PC_two_Dataset(IMAGE_PATH)
    class PC_five_h5_Dataset(Dataset):
        """Concatenation of five HDF5 point-cloud files held fully in memory.

        Global index layout, in order: ``<path>.h5`` (``len1``), ``abo_plys.h5``
        (``len2``), ``train_shapes_pc.h5`` (``len3``),
        ``train_shapes_pc_table_40k.h5`` (``len4``), ``shapenet_plys.h5``
        (``len5``). Each item is one normalized cloud, optionally rescaled by a
        random factor drawn from [0.9, 1.05).
        """

        def __init__(self, path):
            self.data_dir = path
            # np.array() copies each dataset into memory, so the files can be
            # closed immediately instead of leaking open handles.
            with h5py.File(os.path.join('/home/tiangel/datasets', self.data_dir + '.h5'), 'r') as f0:
                self.data0 = np.array(f0['data'])
            with h5py.File('/home/tiangel/datasets/abo_plys.h5', 'r') as f1:
                self.data1 = np.array(f1['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc.h5', 'r') as f2:
                self.data2 = np.array(f2['data'])
            with h5py.File('/home/tiangel/datasets/shapeprogram_data/train_shapes_pc_table_40k.h5', 'r') as f3:
                self.data3 = np.array(f3['data'])
            with h5py.File('/home/tiangel/datasets/shapenet_plys.h5', 'r') as f4:
                self.data4 = np.array(f4['data'])

            self.len1 = self.data0.shape[0]
            self.len2 = self.data1.shape[0]
            self.len3 = self.data2.shape[0]
            self.len4 = self.data3.shape[0]
            self.len5 = self.data4.shape[0]
            self.len = self.len1 + self.len2 + self.len3 + self.len4 + self.len5
            # `args` is the namespace of the enclosing train() call.
            self.do_aug = args.aug

        def _raw_cloud(self, index):
            """Return the un-normalized array row that global ``index`` maps to."""
            leading_sources = (
                (self.data0, self.len1),
                (self.data1, self.len2),
                (self.data2, self.len3),
                (self.data3, self.len4),
            )
            offset = index
            for data, count in leading_sources:
                if offset < count:
                    return data[offset]
                offset -= count
            # Everything past the first four sources belongs to data4. The
            # previous code indexed data3 here, so data4 was never served and
            # the lookup could run past the end of data3.
            return self.data4[offset]

        def __getitem__(self, index):
            cloud = torch.Tensor(self._raw_cloud(index))
            # normalize_points_torch expects a batch axis; add it, then drop it.
            points = normalize_points_torch(cloud.unsqueeze(0)).squeeze()
            if self.do_aug:
                scale = points.new(1).uniform_(0.9, 1.05)
                points[:, 0:3] *= scale
            return points

        def __len__(self):
            return self.len
    # Pick the dataset requested on the command line. For 'shapenet' the
    # category selects which HDF5 file is loaded.
    shapenet_files = {
        'shapenet': 'shapenet_plys_2048.h5',
        'airplane': 'shapenet_plys_2048_airplane.h5',
        'car': 'shapenet_plys_2048_car.h5',
        'chair': 'shapenet_plys_2048_chair.h5',
    }
    if args.dataset == 'full':
        ds = PC_five_h5_Dataset(IMAGE_PATH)
    elif args.dataset == 'shapenet' and args.category in shapenet_files:
        ds = PC_Dataset_h5(shapenet_files[args.category])
    elif args.dataset == 'four':
        ds = PC_four_h5_Dataset(IMAGE_PATH)
    elif args.dataset == 'program':
        ds = PC_Program_Dataset(IMAGE_PATH)
    elif args.dataset == 'completion':
        ds = PC_Dataset_h5('completion_train_data.h5')
    elif args.dataset == 'text':
        ds = PC_text_Dataset('completion_train_data.h5')
    else:
        # The exception used to be constructed but never raised, so an unknown
        # dataset only failed later with an unrelated "ds is not defined" error.
        raise NameError('wrong dataset type')
    assert len(ds) > 0, 'folder does not contain any images'

    if args.gpus > 1:
        # Each DDP process reads its own shard of the dataset.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
                    ds, shuffle=True, num_replicas=args.gpus, rank=rank)
        dl = DataLoader(ds, BATCH_SIZE, sampler=train_sampler, drop_last=True)
    else:
        # NOTE(review): the single-GPU loader iterates in dataset order (no
        # shuffle) — confirm this is intended.
        dl = DataLoader(ds, BATCH_SIZE, drop_last=True)
    # Constructor arguments are kept in a dict so they can be stored in every
    # checkpoint under 'hparams'.
    vae_params = dict(
        image_size = IMAGE_SIZE,
        num_layers = NUM_LAYERS,
        num_tokens = NUM_TOKENS,
        codebook_dim = EMB_DIM,
        hidden_dim   = HIDDEN_DIM,
        num_resnet_blocks = NUM_RESNET_BLOCKS,
        dim1 = args.dim1,
        dim2 = args.dim2,
        radius = args.radius,
        final_points = args.final_points,
        final_dim = args.final_dim,
        vae_type = args.vae_type,
        vae_encode_type = args.vae_encode_type,
    )

    vae = DiscreteVAE(
        **vae_params,
        smooth_l1_loss = SMOOTH_L1_LOSS,
        kl_div_loss_weight = KL_LOSS_WEIGHT
    ).train()
    opt = Adam(vae.parameters(), lr = LEARNING_RATE)
    # The scheduler is stepped once per batch, so T_max is the total number of
    # optimizer steps per process over the whole run.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = opt, T_max = EPOCHS*int(len(ds)/BATCH_SIZE/args.gpus))

    if args.resume:
        checkpoint_path = os.path.join('./outputs/vae_models','vae'+args.save_name+'-cpu.pt')
        # Load onto the CPU first so that every process does not materialise
        # the checkpoint on the device it was saved from; the tensors are moved
        # by vae.cuda() / load_state_dict below.
        load_obj = torch.load(checkpoint_path, map_location='cpu')
        weights, resume_epoch = load_obj['weights'], load_obj['epoch']
        opt_state, sche_state = load_obj.get('opt_state'), load_obj.get('scheduler_state')
        vae.load_state_dict(weights)
        vae.cuda()
        # Both states are optional in the checkpoint (hence .get); passing None
        # to load_state_dict would crash, so skip whichever is absent.
        if opt_state is not None:
            opt.load_state_dict(opt_state)
        if sche_state is not None:
            sched.load_state_dict(sche_state)
    else:
        vae.cuda()
        resume_epoch = 0


    distributed = comm.get_world_size() > 1
    print('Distributed:', distributed)

    if args.gpus > 1:
        vae = DistributedDataParallel(
                vae, device_ids=[rank], broadcast_buffers=False
            )

    # optimizer



    # if distr_backend.is_root_worker():
    # if comm.is_main_process():
    #     # weights & biases experiment tracking

    #     import wandb

    #     model_config = dict(
    #         num_tokens = NUM_TOKENS,
    #         smooth_l1_loss = SMOOTH_L1_LOSS,
    #         num_resnet_blocks = NUM_RESNET_BLOCKS,
    #         kl_loss_weight = KL_LOSS_WEIGHT
    #     )

    #     run = wandb.init(
    #         project = 'dalle_train_vae',
    #         job_type = 'train_model',
    #         config = model_config
    #     )

    # using_deepspeed_sched = False
    # # Prefer scheduler in `deepspeed_config`.
    # if distr_sched is None:
    #     distr_sched = sched
    # elif using_deepspeed:
    #     # We are using a DeepSpeed LR scheduler and want to let DeepSpeed
    #     # handle its scheduling.
    #     using_deepspeed_sched = True

    def save_model(path, epoch, gpus):
        """Write a checkpoint with hyper-parameters, epoch and training state.

        Args:
            path: destination file passed to ``torch.save``.
            epoch: epoch number recorded in the checkpoint.
            gpus: number of GPUs in use; with more than one, the weights are
                taken from ``vae.module`` (the model wrapped by DDP).
        """
        # Under DDP the trainable network sits behind `.module`.
        network = vae.module if gpus > 1 else vae

        checkpoint = {'hparams': vae_params, 'epoch': epoch}
        checkpoint['weights'] = network.state_dict()
        checkpoint['opt_state'] = opt.state_dict()
        checkpoint['scheduler_state'] = sched.state_dict() if sched else None

        torch.save(checkpoint, path)

    # starting temperature

    # Only the main process logs to TensorBoard, and only when requested.
    writer = None
    if comm.is_main_process() and args.tensorboard_flag:
        writer = SummaryWriter(os.path.join('./outputs/vae_outputs', 'test'+args.save_name))
    # NOTE: the temperature stays at its starting value; no annealing is applied.
    temp = STARTING_TEMP
    from scipy.signal import savgol_filter
    train_res_perplexity = []
    for epoch in range(resume_epoch, EPOCHS):
        start_time = time.time()
        if args.gpus > 1:
            # DistributedSampler derives its shuffle from the epoch number;
            # without this call every epoch replays the same ordering.
            dl.sampler.set_epoch(epoch)
        for i, images in enumerate(dl):
            images = images.cuda()

            loss, cd_loss, emd_loss, _, _, perplexity = vae(
                images,
                return_loss = True,
                return_recons = True,
                return_detailed_loss = True,
                temp = temp,
                epoch = epoch,
            )

            opt.zero_grad()
            loss.backward()
            opt.step()
            # The cosine schedule is advanced once per batch.
            sched.step()
            if comm.is_main_process():
                # detach() keeps this safe even if perplexity tracks gradients.
                train_res_perplexity.append(perplexity.detach().cpu().numpy())
                if writer is not None:
                    step = epoch*len(dl)+i
                    writer.add_scalar('train_loss', loss.item(), step)
                    writer.add_scalar('cd_loss', cd_loss.item(), step)
                    writer.add_scalar('emd_loss', emd_loss.item(), step)
                    writer.add_scalar('perplexity', perplexity.item(), step)

            # Periodic checkpoint, written by the main process only.
            if i % 100 == 0 and comm.is_main_process():
                save_model('./outputs/vae_models/vae'+args.save_name+'.pt', epoch, args.gpus)

            if comm.is_main_process() and i % 10 == 0:
                lr = sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f} loss - {loss}')

        elapsed_time_secs = time.time() - start_time
        msg = "1 epoch took: %s secs" % timedelta(seconds=round(elapsed_time_secs))
        print(msg)

    if comm.is_main_process():
        # save final vae and cleanup
        save_model('./outputs/vae_models/vae-final'+args.save_name+'.pt', EPOCHS, args.gpus)
        if writer is not None:
            writer.close()

        # savgol_filter needs at least `window` samples; shorter runs used to
        # crash here after training had finished, so plot the raw curve instead.
        window, polyorder = 201, 7
        if len(train_res_perplexity) >= window:
            train_res_perplexity_smooth = savgol_filter(train_res_perplexity, window, polyorder)
        else:
            train_res_perplexity_smooth = train_res_perplexity
        import matplotlib.pyplot as plt
        plt.figure(figsize=(8,8))
        plt.plot(train_res_perplexity_smooth)
        plt.savefig('./outputs/visu_perplexity/perplexity'+args.save_name+'.png',dpi=600)
        plt.close()

def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True; this converter is used instead.

    Args:
        value: a bool (returned unchanged) or a string.

    Returns:
        bool: the parsed value; the empty string maps to False.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--image_folder', type = str, required = True,
                        help='path to your folder of images for learning the discrete VAE and its codebook')

    parser.add_argument('--image_size', type = int, required = False, default = 128,
                        help='image size')

    parser = distributed_utils.wrap_arg_parser(parser)

    train_group = parser.add_argument_group('Training settings')

    train_group.add_argument('--epochs', type = int, default = 20, help = 'number of epochs')

    train_group.add_argument('--batch_size', type = int, default = 8, help = 'batch size')

    train_group.add_argument('--learning_rate', type = float, default = 2e-3, help = 'learning rate')

    train_group.add_argument('--lr_decay_rate', type = float, default = 0.98, help = 'learning rate decay')

    train_group.add_argument('--starting_temp', type = float, default = 1., help = 'starting temperature')

    train_group.add_argument('--temp_min', type = float, default = 0.05, help = 'minimum temperature to anneal to')

    train_group.add_argument('--anneal_rate', type = float, default = 2e-4, help = 'temperature annealing rate')

    train_group.add_argument('--num_images_save', type = int, default = 4, help = 'number of images to save')

    model_group = parser.add_argument_group('Model settings')

    model_group.add_argument('--num_tokens', type = int, default = 8192, help = 'number of image tokens')

    model_group.add_argument('--num_layers', type = int, default = 3, help = 'number of layers (should be 3 or above)')

    model_group.add_argument('--num_resnet_blocks', type = int, default = 2, help = 'number of residual net blocks')

    model_group.add_argument('--smooth_l1_loss', dest = 'smooth_l1_loss', action = 'store_true')

    model_group.add_argument('--emb_dim', type = int, default = 512, help = 'embedding dimension')

    model_group.add_argument('--hidden_dim', type = int, default = 256, help = 'hidden dimension')

    # The help texts below used to be copy-pasted ('hidden dimension' /
    # 'KL loss weight'); they now describe the option they belong to.
    model_group.add_argument('--dim1', type = int, default = 16, help = 'value passed to DiscreteVAE as dim1')

    model_group.add_argument('--dim2', type = int, default = 32, help = 'value passed to DiscreteVAE as dim2')

    model_group.add_argument('--final_points', type = int, default = 16, help = 'value passed to DiscreteVAE as final_points')

    model_group.add_argument('--gpus', type = int, default = 1, help = 'number of GPUs (one training process per GPU)')

    model_group.add_argument('--radius', type = float, default = 0.4, help = 'value passed to DiscreteVAE as radius')

    model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')

    model_group.add_argument('--final_dim', type = int, default = 2048, help = 'value passed to DiscreteVAE as final_dim')

    model_group.add_argument('--save_name', type = str, default = '1', help = 'suffix used in checkpoint, log and plot file names')

    model_group.add_argument('--aug', type = str2bool, default = True, help = 'apply random scale augmentation to each point cloud (true/false)')

    model_group.add_argument('--resume', type = str2bool, default = False, help = 'resume from ./outputs/vae_models/vae<save_name>-cpu.pt (true/false)')

    model_group.add_argument('--dataset', type = str, default = 'four', help = 'dataset to train on: full, shapenet, four, program, completion or text')

    model_group.add_argument('--category', type = str, default = 'shapenet', help = 'subset used with --dataset shapenet: shapenet, airplane, car or chair')

    model_group.add_argument('--vae_type', type = int, default = 5, help = 'value passed to DiscreteVAE as vae_type')

    model_group.add_argument('--vae_encode_type', type = int, default = 4, help = 'value passed to DiscreteVAE as vae_encode_type')

    model_group.add_argument('--tensorboard_flag', type = str2bool, default = False, help = 'log training scalars to TensorBoard (true/false)')

    model_group.add_argument('--port', type = str, default = '12357', help = 'port for parallel')

    args = parser.parse_args()

    if args.gpus == 1:
        # Single process: run in place as rank 0 (train() ignores the rank
        # unless args.gpus > 1).
        train(0, args)
    else:
        # setup_ddp() rendezvouses through env://, which reads these variables.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = args.port
        args.world_size = args.gpus
        # mp.spawn calls train(i, args) with i in range(nprocs).
        mp.spawn(train, nprocs=args.gpus, args=(args,))